#############################################################################
# Jackknife bias correction 
#
# 1. No trimming case
# 2. Trimming 
# 
#############################################################################

# Start from a clean session: remove every (non-hidden) object from the
# workspace and close all open graphics devices.
# NOTE(review): the save.image() calls further down store the whole workspace,
# so anything left over from a previous session would otherwise end up in the
# saved .RData file. Keep this reset unless the saving strategy changes.
rm(list=ls())
graphics.off()

library(data.table)
library(ggplot2)
library(haven)
library(fixest)
library("RColorBrewer")
library(xtable)
library(Hmisc) 
library(quantreg) # For rearranging CDF

# ---- Inputs, helpers and output locations ----------------------------------
# Workspace saved by an earlier step. `dt` and `dt_res`, used further down, are
# not created in this script and presumably come from here (TODO confirm).
load("Data_new/FE_results.RData")

# Helper scripts, sourced after the load so their definitions take precedence.
# The jack2_* functions called below are presumably defined in these files.
source("Code/0. CODE Auxiliary.R")
source("Code/4a. CODE Jackknife version 2.R")

# Output folders (data / figures) and start-of-run timestamp
save_dta <- "Data_new"
save_res <- "Results"
time0 <- Sys.time()

#############################################################################
# 1. Apply jackknife (no trimming case)
#############################################################################

# Settings shared by every jackknife call in this script
nsplit <- 5
sseed <- 20231221

# Baseline model: jackknife on the full sample (trim = 0)
J0 <- jack2_split(data = dt, trim = 0, nsplit = nsplit, sseed = sseed)

# Sanity check: the jackknife output's mg0 should line up with the mg3
# results that were already available in dt_res
summary(J0$dt_jack[, .(mg0)])
summary(dt_res[, .(mg3)])

# Keep the jackknife output and attach ddr and mg3 from dt_res by county-year.
# all = TRUE keeps rows that appear in only one of the two tables.
dt_jack <- merge(
  J0$dt_jack,
  dt_res[, .(fips, year, ddr, mg3)],
  by = c("fips", "year"),
  all = TRUE
)
setnames(dt_jack, old = "mg3", new = "mg_fe")

# Inspect the merged table and the gap between the two estimates
dim(dt_jack)
summary(dt_jack)
summary(dt_jack[, .(mg0 - mg_fe)])

# mg0 is only needed for the comparison above; drop it before the CDF step
dt_jack[, mg0 := NULL]

# Jackknife version of the CDF of marginal effects at the county-year level
mat_cdf <- jack2_cdf(data = dt_jack, bval = 45, nsplit = nsplit)

# Simulated data from the jackknife estimate of the CDF
pdf_sim <- jack2_sim(cdf_data = mat_cdf, sseed = sseed)
summary(pdf_sim)

# Store the whole workspace with the no-trimming results
save.image("Data_new/JK2_results.RData")

#############################################################################
# 2. Robustness to trimming 
#############################################################################

# Trimming proportions to evaluate (0 corresponds to the no-trimming case)
trim_vals <- c(0, 0.01, 0.02, 0.05, 0.08, seq(0.1, 0.5, 0.05))

# Results matrix, one row per trimming value: trim percentage, number of
# observations, number of counties, then three versions of the mean and of the
# variance. Preallocated so rows are filled by index instead of growing the
# matrix with rbind() from a dummy NA row.
Res <- matrix(NA_real_, nrow = length(trim_vals), ncol = 9)
colnames(Res) <- c("Percentage Trim ", "N obs", "N counties",
                   "Mean", "Mean form", "Mean sim",
                   "Var", "Var for", "Var sim")

# seq_along() is safe even if trim_vals is empty (the loop is simply skipped)
for (j in seq_along(trim_vals)) {

  # Progress indicator
  print(j)

  # Jackknife after trimming
  Jnow <- jack2_split(data = dt, trim = trim_vals[j], nsplit = nsplit, sseed = sseed)

  # jack2_cdf() is called with a column named mg_fe in section 1, so rename
  # mg0 accordingly (setnames works by reference, so Jnow$dt_jack is renamed too)
  dt_now <- Jnow$dt_jack
  setnames(dt_now, "mg0", "mg_fe")

  # Jackknife version of the CDF of marginal effects at the county-year level
  cdf_now <- jack2_cdf(data = dt_now, bval = 45, nsplit = nsplit)

  # Simulated data from jackknife estimate of the CDF
  pdf_now <- jack2_sim(cdf_data = cdf_now, sseed = sseed)

  Rnow <- c(trim_vals[j], Jnow$nobs, Jnow$nc,
            Jnow$R["Mean", "tot"], Jnow$R["Mean", "corr"], mean(pdf_now$cdf_mont, na.rm = TRUE),
            Jnow$R["Var", "tot"], Jnow$R["Var", "corr"], var(pdf_now$cdf_mont, na.rm = TRUE))

  # Fail loudly if a component is missing or not scalar, rather than letting
  # the row be silently recycled or truncated
  stopifnot(length(Rnow) == ncol(Res))
  Res[j, ] <- Rnow

  # Checkpoint the workspace every third iteration (rows not yet computed are NA)
  if (j %% 3 == 0) {
    save.image("Data_new/JK2_results.RData")
  }
}

Res

# Tabulate
#dig=c(0,2,0,0,rep(3,6))
#print(xtable(Res,type="latex",digits=dig,display=c("s",rep("f",9))),hline.after = NULL,
#      file=paste0(save_res,"/Table_FE_trim.tex"),include.rownames = FALSE,include.colnames=FALSE,
#      sanitize.text.function = function(x){x},only.contents = TRUE)

# Convert the results matrix to a data.table and give the nine columns short
# names that the plotting code can refer to
dt_plot <- setnames(
  data.table(Res),
  c("trim", "nobs", "nc",
    "mean_ori", "mean_form", "mean_corr",
    "var_ori", "var_form", "var_corr")
)
summary(dt_plot)

# Plots
# Shared base layer: trimming proportion on the x axis, black-and-white theme,
# no grid lines and no legend. The mean and variance figures add lines to it.
g <- ggplot(dt_plot, aes(x = trim)) + theme_bw() +
  theme(title = element_text(size = 14), plot.title = element_text(hjust = 0.5),
        plot.subtitle = element_text(hjust = 0.5),
        axis.text = element_text(size = 18), axis.title = element_text(size = 18),
        panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
        legend.position = "none") +
  xlab("Sample proportion trimmed")

# Mean: dashed = mean_ori, solid = mean_corr, dotted = mean_form.
# The plot is stored and passed to ggsave() explicitly so the saved file does
# not depend on ggplot2's implicit "last plot" state; the surrounding
# parentheses keep the top-level auto-printing of the figure.
# NOTE: ylim() drops points outside [-10, -5] from the figure.
(g_mean <- g + geom_line(aes(y = mean_ori), linetype = "dashed") +
  geom_line(aes(y = mean_corr)) +
  geom_line(aes(y = mean_form), linetype = "dotted") + ylab("Mean") + ylim(c(-10, -5)))
ggsave(file.path(save_res, "FE_trim_mean.pdf"), plot = g_mean)

# Variance: same line-type convention as the mean figure
(g_var <- g + geom_line(aes(y = var_ori), linetype = "dashed") +
  geom_line(aes(y = var_corr)) +
  geom_line(aes(y = var_form), linetype = "dotted") + ylab("Variance")) #+ ylim(c(0,30))
ggsave(file.path(save_res, "FE_trim_var.pdf"), plot = g_var)

# Drop the raw results matrix and the last jackknife fit, record the end time,
# and write the final workspace to the same file used by the earlier saves
rm(list = c("Res", "Jnow"))
time1 <- Sys.time()
save.image("Data_new/JK2_results.RData")
